# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX is Autograd and XLA

load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load(
    "@xla//third_party/py:py_import.bzl",
    "py_import",
)
load(
    "@xla//third_party/py:py_manylinux_compliance_test.bzl",
    "verify_manylinux_compliance_test",
)
load(
    "//jaxlib:jax.bzl",
    "PLATFORM_TAGS_DICT",
    "jax_py_test",
    "jax_wheel",
    "pytype_strict_library",
    "pytype_test",
)

# License classification applied to the targets of this package.
licenses(["notice"])  # Apache 2

# Every target below is visible to all other packages unless it sets its own
# `visibility`.
package(default_visibility = ["//visibility:public"])

# Expose wheel_size_test.py as a label; the *_wheel_size_test targets in this
# file use it as their `srcs` and `main`.
exports_files(["wheel_size_test.py"])

# Generates platform_tags.py: a one-line Python module that assigns the
# textual form of PLATFORM_TAGS_DICT (loaded from //jaxlib:jax.bzl) to a
# variable of the same name. It is consumed as a source of :build_utils.
genrule(
    name = "platform_tags_py",
    srcs = [],
    outs = ["platform_tags.py"],
    # The dict is formatted into the command string with `%s` when this file
    # is loaded. The text is wrapped in single quotes for the shell, so the
    # formatted dict must not itself contain a single quote.
    cmd = "echo 'PLATFORM_TAGS_DICT = %s' > $@;" % PLATFORM_TAGS_DICT,
)

# Python library shared by the wheel-building binaries in this package.
# Its sources are the checked-in build_utils.py plus the platform_tags.py
# file produced by the :platform_tags_py genrule.
pytype_strict_library(
    name = "build_utils",
    srcs = [
        "build_utils.py",
        ":platform_tags_py",
    ],
)

# Executable wrapping build_wheel.py. It is the `wheel_binary` of the
# :jaxlib_wheel and :jaxlib_wheel_editable targets in this file.
py_binary(
    name = "build_wheel",
    srcs = ["build_wheel.py"],
    # Files placed in the binary's runfiles: the license, the jaxlib package
    # and its binaries, packaging metadata (README.md, setup.py), the XLA
    # client/extension, and the XLA FFI API headers.
    data = [
        "LICENSE.txt",
        "//jaxlib",
        "//jaxlib:README.md",
        "//jaxlib:jaxlib_binaries",
        "//jaxlib:setup.py",
        "//jaxlib/xla:xla_client.py",
        "//jaxlib/xla:xla_extension",
        "@xla//xla/ffi/api:api.h",
        "@xla//xla/ffi/api:c_api.h",
        "@xla//xla/ffi/api:ffi.h",
    ],
    # Python dependencies: the local helper library, the Bazel runfiles
    # library, and the `build`, `setuptools` and `wheel` PyPI packages.
    deps = [
        ":build_utils",
        "@bazel_tools//tools/python/runfiles",
        "@pypi_build//:pkg",
        "@pypi_setuptools//:pkg",
        "@pypi_wheel//:pkg",
    ],
)

# Test implemented in build_wheel_test.py. The :build_wheel binary is supplied
# as a data dependency, and the Bazel runfiles library is available to the
# test as a Python dependency.
jax_py_test(
    name = "build_wheel_test",
    srcs = ["build_wheel_test.py"],
    data = [":build_wheel"],
    deps = [
        "@bazel_tools//tools/python/runfiles",
    ],
)

# Shared library (`linkshared = True`) named pjrt_c_api_gpu_plugin.so. It is
# shipped by :build_gpu_plugin_wheel through that target's `data`.
cc_binary(
    name = "pjrt_c_api_gpu_plugin.so",
    # Link with :gpu_version_script.lds as the linker version script and make
    # unresolved symbols a link-time error (--no-undefined).
    linkopts = [
        "-Wl,--version-script,$(location :gpu_version_script.lds)",
        "-Wl,--no-undefined",
    ],
    linkshared = True,
    # :gpu_version_script.lds is listed here so that the $(location ...)
    # reference in `linkopts` can be expanded. The if_cuda(...) / if_rocm(...)
    # lists add backend-specific dependencies (presumably only when that
    # backend is enabled -- confirm in the build_defs.bzl files they are
    # loaded from).
    deps = [
        ":gpu_version_script.lds",
        "@xla//xla/pjrt/c:pjrt_c_api_gpu",
        "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds",
        "@xla//xla/service:gpu_plugin",
    ] + if_cuda([
        "//jaxlib/mosaic/gpu:custom_call",
        "@xla//xla/stream_executor:cuda_platform",
    ]) + if_rocm([
        "@xla//xla/stream_executor:rocm_platform",
    ]),
)

# Executable wrapping build_gpu_plugin_wheel.py. It is the `wheel_binary` of
# the jax_cuda_pjrt_wheel* and jax_rocm_pjrt_wheel* targets in this file.
py_binary(
    name = "build_gpu_plugin_wheel",
    srcs = ["build_gpu_plugin_wheel.py"],
    # Always available at runtime: the license and the PJRT GPU plugin shared
    # library built above. The if_cuda(...) / if_rocm(...) lists add the
    # version file, GPU support files and the jax_plugins packaging files
    # (pyproject.toml, setup.py, __init__.py) for the respective backend; the
    # CUDA list additionally includes the cuda-nvvm files.
    data = [
        "LICENSE.txt",
        ":pjrt_c_api_gpu_plugin.so",
    ] + if_cuda([
        "//jaxlib:version",
        "//jaxlib/cuda:cuda_gpu_support",
        "//jax_plugins/cuda:pyproject.toml",
        "//jax_plugins/cuda:setup.py",
        "//jax_plugins/cuda:__init__.py",
        "@local_config_cuda//cuda:cuda-nvvm",
    ]) + if_rocm([
        "//jaxlib:version",
        "//jaxlib/rocm:rocm_gpu_support",
        "//jax_plugins/rocm:pyproject.toml",
        "//jax_plugins/rocm:setup.py",
        "//jax_plugins/rocm:__init__.py",
    ]),
    # Same Python dependencies as the other wheel-building binaries in this
    # package.
    deps = [
        ":build_utils",
        "@bazel_tools//tools/python/runfiles",
        "@pypi_build//:pkg",
        "@pypi_setuptools//:pkg",
        "@pypi_wheel//:pkg",
    ],
)

# Executable wrapping build_gpu_kernels_wheel.py. It is the `wheel_binary` of
# the jax_cuda_plugin_wheel* and jax_rocm_plugin_wheel* targets in this file.
py_binary(
    name = "build_gpu_kernels_wheel",
    srcs = ["build_gpu_kernels_wheel.py"],
    # Only the license is unconditional. The if_cuda(...) / if_rocm(...) lists
    # add the version file, the plugin extension, GPU support files and the
    # plugin_pyproject.toml / plugin_setup.py packaging files for the
    # respective backend; the CUDA list additionally includes Mosaic GPU and
    # the cuda-nvvm files.
    data = [
        "LICENSE.txt",
    ] + if_cuda([
        "//jaxlib:version",
        "//jaxlib/mosaic/gpu:mosaic_gpu",
        "//jaxlib/cuda:cuda_plugin_extension",
        "//jaxlib/cuda:cuda_gpu_support",
        "//jax_plugins/cuda:plugin_pyproject.toml",
        "//jax_plugins/cuda:plugin_setup.py",
        "@local_config_cuda//cuda:cuda-nvvm",
    ]) + if_rocm([
        "//jaxlib:version",
        "//jaxlib/rocm:rocm_plugin_extension",
        "//jaxlib/rocm:rocm_gpu_support",
        "//jax_plugins/rocm:plugin_pyproject.toml",
        "//jax_plugins/rocm:plugin_setup.py",
    ]),
    # Same Python dependencies as the other wheel-building binaries in this
    # package.
    deps = [
        ":build_utils",
        "@bazel_tools//tools/python/runfiles",
        "@pypi_build//:pkg",
        "@pypi_setuptools//:pkg",
        "@pypi_wheel//:pkg",
    ],
)

# Matches when the target OS satisfies either of the two macOS constraint
# labels (`osx` or `macos`) from @platforms.
selects.config_setting_group(
    name = "macos",
    match_any = [
        "@platforms//os:osx",
        "@platforms//os:macos",
    ],
)

# Matches when the target CPU satisfies either of the two 64-bit ARM
# constraint labels (`aarch64` or `arm64`) from @platforms.
selects.config_setting_group(
    name = "arm64",
    match_any = [
        "@platforms//cpu:aarch64",
        "@platforms//cpu:arm64",
    ],
)

# Matches only when both the :arm64 and :macos groups match.
selects.config_setting_group(
    name = "macos_arm64",
    match_all = [
        ":arm64",
        ":macos",
    ],
)

# Matches only when the CPU is x86_64 and the :macos group matches.
selects.config_setting_group(
    name = "macos_x86_64",
    match_all = [
        "@platforms//cpu:x86_64",
        ":macos",
    ],
)

# Matches only when the CPU is x86_64 and the OS is Windows.
selects.config_setting_group(
    name = "win_amd64",
    match_all = [
        "@platforms//cpu:x86_64",
        "@platforms//os:windows",
    ],
)

# Matches only when the CPU is x86_64 and the OS is Linux.
selects.config_setting_group(
    name = "linux_x86_64",
    match_all = [
        "@platforms//cpu:x86_64",
        "@platforms//os:linux",
    ],
)

# Matches only when the :arm64 group matches and the OS is Linux.
selects.config_setting_group(
    name = "linux_aarch64",
    match_all = [
        ":arm64",
        "@platforms//os:linux",
    ],
)

# String build setting that defaults to the empty string. Nothing in this
# file reads it; presumably it is consumed by the jax_wheel macro to record a
# git hash -- confirm in //jaxlib:jax.bzl.
string_flag(
    name = "jaxlib_git_hash",
    build_setting_default = "",
)

# String build setting that defaults to "dist". Nothing in this file reads
# it; presumably it selects where built wheels are written -- confirm in
# //jaxlib:jax.bzl.
string_flag(
    name = "output_path",
    build_setting_default = "dist",
)

# Labels of the NVIDIA `*_cu12` PyPI wheels that the CUDA py_import targets in
# this file pass as `wheel_deps`. Each entry expands to
# "@pypi_nvidia_<library>_cu12//:whl"; the order of the library names below is
# the order of the resulting labels.
NVIDIA_WHEELS_DEPS = [
    "@pypi_nvidia_%s_cu12//:whl" % nvidia_library
    for nvidia_library in [
        "cublas",
        "cuda_cupti",
        "cuda_runtime",
        "cudnn",
        "cufft",
        "cusolver",
        "cusparse",
        "nccl",
        "nvjitlink",
    ]
]

# The "jaxlib" wheel, produced by running the :build_wheel binary.
# `no_abi = False` presumably means the wheel is tagged for a specific Python
# ABI -- confirm against the jax_wheel macro in //jaxlib:jax.bzl.
jax_wheel(
    name = "jaxlib_wheel",
    no_abi = False,
    wheel_binary = ":build_wheel",
    wheel_name = "jaxlib",
)

# Wraps the :jaxlib_wheel output with py_import (presumably so other targets
# can depend on the built wheel as a Python library -- see
# @xla//third_party/py:py_import.bzl). No extra `wheel_deps` are declared.
py_import(
    name = "jaxlib_py_import",
    wheel = ":jaxlib_wheel",
)

# Editable (`editable = True`) variant of the "jaxlib" wheel, built by the
# same :build_wheel binary. Unlike :jaxlib_wheel it does not set `no_abi`, so
# the macro's default applies.
jax_wheel(
    name = "jaxlib_wheel_editable",
    editable = True,
    wheel_binary = ":build_wheel",
    wheel_name = "jaxlib",
)

# The "jax_cuda12_plugin" wheel, produced by :build_gpu_kernels_wheel with
# CUDA enabled. `platform_version` is hard-coded to "12", in line with the
# "cuda12" wheel name.
jax_wheel(
    name = "jax_cuda_plugin_wheel",
    enable_cuda = True,
    no_abi = False,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "12",
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_cuda12_plugin",
)

# Wraps the :jax_cuda_plugin_wheel output with py_import. The NVIDIA wheels
# listed in NVIDIA_WHEELS_DEPS are passed as `wheel_deps` through if_cuda(...)
# (presumably only when CUDA is enabled).
py_import(
    name = "jax_cuda_plugin_py_import",
    wheel = ":jax_cuda_plugin_wheel",
    wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS),
)

# Editable (`editable = True`) variant of the "jax_cuda12_plugin" wheel,
# built by :build_gpu_kernels_wheel with CUDA enabled. It does not set
# `no_abi`, so the macro's default applies.
jax_wheel(
    name = "jax_cuda_plugin_wheel_editable",
    editable = True,
    enable_cuda = True,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "12",
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_cuda12_plugin",
)

# The "jax_rocm60_plugin" wheel, produced by :build_gpu_kernels_wheel with
# ROCm enabled. `platform_version` is hard-coded to "60", in line with the
# "rocm60" wheel name.
jax_wheel(
    name = "jax_rocm_plugin_wheel",
    enable_rocm = True,
    no_abi = False,
    platform_version = "60",
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_rocm60_plugin",
)

# Editable (`editable = True`) variant of the "jax_rocm60_plugin" wheel,
# built by :build_gpu_kernels_wheel with ROCm enabled. It does not set
# `no_abi`, so the macro's default applies.
jax_wheel(
    name = "jax_rocm_plugin_wheel_editable",
    editable = True,
    enable_rocm = True,
    platform_version = "60",
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_rocm60_plugin",
)

# The "jax_cuda12_pjrt" wheel, produced by :build_gpu_plugin_wheel with CUDA
# enabled. In contrast to the jaxlib and *_plugin wheels above, this one sets
# `no_abi = True` (presumably because it ships no Python-ABI-specific
# extension -- confirm against the jax_wheel macro).
jax_wheel(
    name = "jax_cuda_pjrt_wheel",
    enable_cuda = True,
    no_abi = True,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "12",
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_cuda12_pjrt",
)

# Wraps the :jax_cuda_pjrt_wheel output with py_import. The NVIDIA wheels
# listed in NVIDIA_WHEELS_DEPS are passed as `wheel_deps` through if_cuda(...)
# (presumably only when CUDA is enabled).
py_import(
    name = "jax_cuda_pjrt_py_import",
    wheel = ":jax_cuda_pjrt_wheel",
    wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS),
)

# Editable (`editable = True`) variant of the "jax_cuda12_pjrt" wheel, built
# by :build_gpu_plugin_wheel with CUDA enabled. It does not set `no_abi`, so
# the macro's default applies.
jax_wheel(
    name = "jax_cuda_pjrt_wheel_editable",
    editable = True,
    enable_cuda = True,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "12",
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_cuda12_pjrt",
)

# The "jax_rocm60_pjrt" wheel, produced by :build_gpu_plugin_wheel with ROCm
# enabled and `no_abi = True`, mirroring :jax_cuda_pjrt_wheel.
jax_wheel(
    name = "jax_rocm_pjrt_wheel",
    enable_rocm = True,
    no_abi = True,
    platform_version = "60",
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_rocm60_pjrt",
)

# Editable (`editable = True`) variant of the "jax_rocm60_pjrt" wheel, built
# by :build_gpu_plugin_wheel with ROCm enabled. It does not set `no_abi`, so
# the macro's default applies.
jax_wheel(
    name = "jax_rocm_pjrt_wheel_editable",
    editable = True,
    enable_rocm = True,
    platform_version = "60",
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_rocm60_pjrt",
)

# Compliance tags handed to the verify_manylinux_compliance_test targets
# below. Each one is built by joining, with "_", the elements stored in
# PLATFORM_TAGS_DICT under the ("Linux", <arch>) key.
AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")])

PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")])

X86_64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "x86_64")])

# Manylinux compliance check for the :jaxlib_wheel output, given the expected
# tag for each of aarch64, ppc64le and x86_64. Tagged "manual", so it only
# runs when requested explicitly rather than through wildcard patterns.
verify_manylinux_compliance_test(
    name = "jaxlib_manylinux_compliance_test",
    aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
    ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
    test_tags = [
        "manual",
    ],
    wheel = ":jaxlib_wheel",
    x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)

# Manylinux compliance check for the :jax_cuda_plugin_wheel output, given the
# expected tag for each of aarch64, ppc64le and x86_64. Tagged "manual", so it
# only runs when requested explicitly rather than through wildcard patterns.
verify_manylinux_compliance_test(
    name = "jax_cuda_plugin_manylinux_compliance_test",
    aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
    ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
    test_tags = [
        "manual",
    ],
    wheel = ":jax_cuda_plugin_wheel",
    x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)

# Manylinux compliance check for the :jax_cuda_pjrt_wheel output, given the
# expected tag for each of aarch64, ppc64le and x86_64. Tagged "manual", so it
# only runs when requested explicitly rather than through wildcard patterns.
verify_manylinux_compliance_test(
    name = "jax_cuda_pjrt_manylinux_compliance_test",
    aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
    ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
    test_tags = [
        "manual",
    ],
    wheel = ":jax_cuda_pjrt_wheel",
    x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)

# Runs wheel_size_test.py against the :jaxlib_wheel output. The wheel's
# location is passed via --wheel-path and the limit via --max-size-mib=110
# (presumably mebibytes, per the flag name). The wheel is listed in `data` so
# that $(location ...) can be expanded and the file is present at test time.
# Tagged "manual" (explicit invocation only) and "notap".
pytype_test(
    name = "jaxlib_wheel_size_test",
    srcs = [":wheel_size_test.py"],
    args = [
        "--wheel-path=$(location :jaxlib_wheel)",
        "--max-size-mib=110",
    ],
    data = [":jaxlib_wheel"],
    main = "wheel_size_test.py",
    tags = [
        "manual",
        "notap",
    ],
)

# Runs wheel_size_test.py against the :jax_cuda_plugin_wheel output. The
# wheel's location is passed via --wheel-path and the limit via
# --max-size-mib=20 (presumably mebibytes, per the flag name). The wheel is
# listed in `data` so that $(location ...) can be expanded and the file is
# present at test time. Tagged "manual" (explicit invocation only) and
# "notap".
pytype_test(
    name = "jax_cuda_plugin_wheel_size_test",
    srcs = [":wheel_size_test.py"],
    args = [
        "--wheel-path=$(location :jax_cuda_plugin_wheel)",
        "--max-size-mib=20",
    ],
    data = [":jax_cuda_plugin_wheel"],
    main = "wheel_size_test.py",
    tags = [
        "manual",
        "notap",
    ],
)

# Runs wheel_size_test.py against the :jax_cuda_pjrt_wheel output. The
# wheel's location is passed via --wheel-path and the limit via
# --max-size-mib=120 (presumably mebibytes, per the flag name). The wheel is
# listed in `data` so that $(location ...) can be expanded and the file is
# present at test time. Tagged "manual" (explicit invocation only) and
# "notap".
pytype_test(
    name = "jax_cuda_pjrt_wheel_size_test",
    srcs = [":wheel_size_test.py"],
    args = [
        "--wheel-path=$(location :jax_cuda_pjrt_wheel)",
        "--max-size-mib=120",
    ],
    data = [":jax_cuda_pjrt_wheel"],
    main = "wheel_size_test.py",
    tags = [
        "manual",
        "notap",
    ],
)
